# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
import itertools
from typing import Any

from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax._src import test_util as jtu
import jax.experimental.pallas.ops.tpu.megablox as mblx
import jax.numpy as jnp
import numpy as np

import hypothesis as hp
import hypothesis.strategies as hps

jax.config.parse_flags_with_absl()

P = jax.sharding.PartitionSpec

partial = functools.partial

hp.settings.register_profile(
    "deterministic",
    database=None,
    derandomize=True,
    deadline=None,
    max_examples=10,
    print_blob=True,
)
hp.settings.load_profile("deterministic")

def seed_strategy() -> hps.SearchStrategy[int]:
  return hps.integers(min_value=0, max_value=4)

@hps.composite
def group_strategy(
    draw: hps.DrawFn,
    max_groups: int = 32,
    max_stride: int = 32,
    min_groups: int = 1,
) -> tuple[int, int]:
  assert max_stride <= max_groups

  # Sample the group stride, i.e. the number of groups owned by each shard.
  group_stride = draw(hps.integers(min_value=1, max_value=max_stride))

  # Sample the number of groups as a multiple of the stride to ensure that we
  # have an equal number of groups per shard. Round down s.t. num_groups <=
  # max_groups.
  num_groups = group_stride * draw(
      hps.integers(min_value=min_groups, max_value=max_groups // group_stride)
  )
  return num_groups, group_stride
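
# For example (illustrative, not a drawn value): if group_stride is drawn as 4
# and the multiplier as 3, group_strategy yields (num_groups=12,
# group_stride=4), i.e. three shards of four groups each in the sharded-groups
# test below.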

@hps.composite
def group_sizes_strategy(
    draw: hps.DrawFn, m: int, num_groups: int
) -> jnp.ndarray:
  # Randomly sample the ends of the groups in the m-dimension. Let the fuzzer
  # sample with replacement so that it's possible to get zero-sized groups. Get
  # 'num_groups - 1' run ends. The final group will end at 'm'.
  ends_no_final = np.sort(
      np.array(
          [
              draw(hps.integers(min_value=0, max_value=m))
              for _ in range(num_groups - 1)
          ],
          dtype=np.int32,
      ),
  )
  ends = np.concatenate([ends_no_final, np.array([m], dtype=np.int32)])

  # Calculate the run starts by shifting ends 1 to the right. The first run
  # starts at zero.
  starts = np.concatenate([np.zeros(1, dtype=np.int32), ends_no_final])
  return jnp.array(ends - starts, dtype=jnp.int32)
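
# For example (illustrative, not a drawn value): with m=8 and num_groups=3,
# ends_no_final might be [2, 5], giving ends=[2, 5, 8], starts=[0, 2, 5], and
# group sizes [2, 3, 3], which sum to m. Duplicate draws (e.g. [2, 2]) produce
# zero-sized groups.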

# Test shapes, given as (m, k, n), for the grouped matmul tests.
GROUPED_MATMUL_TESTS = (
    (128, 128, 128),  # Small.
    (512, 2048, 256),  # Big.
    (128, 8, 16),  # Test partial tiles.
)

def random_dense(
    shape: tuple[int, ...],
    key: jax.Array,
    dtype: jnp.dtype,
    limit: float | None = None,
) -> jnp.ndarray:
  if limit is None:
    # By default, scale the sampling range inversely with the number of
    # elements so values (and their accumulated products) stay small.
    limit = 1 / np.prod(shape)
  x = jax.random.uniform(key, shape, dtype, minval=-limit, maxval=limit)  # pylint: disable=invalid-unary-operand-type
  # Round-trip through bfloat16 so the inputs are exactly representable in the
  # lowest-precision dtype under test.
  return x.astype(jnp.bfloat16).astype(dtype)

def dot(
    lhs: jnp.ndarray,
    rhs: jnp.ndarray,
    transpose_lhs: bool = False,
    transpose_rhs: bool = False,
    preferred_element_type: jnp.dtype = jnp.float32,
) -> jnp.ndarray:
  lhs = jnp.transpose(lhs) if transpose_lhs else lhs
  rhs = jnp.transpose(rhs) if transpose_rhs else rhs
  return jax.lax.dot(lhs, rhs, preferred_element_type=preferred_element_type)

def reference_gmm(
    lhs: jnp.ndarray,
    rhs: jnp.ndarray,
    group_sizes: jnp.ndarray,
    preferred_element_type: jnp.dtype = jnp.float32,
) -> jnp.ndarray:
  start = 0
  out = []
  for i, size in enumerate(group_sizes):
    result = dot(
        lhs[start : start + size, :],
        rhs[i, :, :],
        preferred_element_type=preferred_element_type,
    )

    out.append(result)
    start += size
  return jnp.concatenate(out, axis=0)
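
# Shape sketch for the reference above: with lhs of shape (m, k), rhs of shape
# (num_groups, k, n), and group_sizes summing to m, group i contributes
# lhs[start_i : start_i + size_i, :] @ rhs[i] of shape (size_i, n), and the
# concatenated output has shape (m, n).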

def with_dtype_arguments(xs: tuple[Any, ...]) -> tuple[Any, ...]:
  dtypes = [jnp.float32, jnp.bfloat16]

  result = []
  for x in xs:
    for dtypes_tuple in itertools.product(dtypes, dtypes, dtypes):
      result.append(x + dtypes_tuple)
  return tuple(result)

def with_transpose_argument(xs: tuple[Any, ...]) -> tuple[Any, ...]:
  flags = [False, True]
  result = []
  for x in xs:
    for flag in flags:
      result.append(x + (flag,))
  return tuple(result)

def tolerances(
    lhs_dtype: jnp.dtype, rhs_dtype: jnp.dtype, out_dtype: jnp.dtype
) -> tuple[float, float]:
  if (
      lhs_dtype == jnp.bfloat16
      or rhs_dtype == jnp.bfloat16
      or out_dtype == jnp.bfloat16
  ):
    if jtu.is_device_tpu(7):
      return 2e-2, 1e-2  # atol, rtol
    return 1e-3, 1e-2  # atol, rtol
  return 1e-3, 1e-5  # atol, rtol

# TODO(tgale): Fix errors with strict dtype promotion.
@jtu.with_config(jax_numpy_dtype_promotion="standard")
@jtu.thread_unsafe_test_class()  # hypothesis is not thread safe
class GroupedMatmulTest(jtu.JaxTestCase):

  def setUp(self):
    if not jtu.test_device_matches(["tpu"]):
      self.skipTest("Test requires TPU device.")

    super().setUp()
    self.key = jax.random.PRNGKey(1234)

  def assert_allclose(
      self,
      out: jnp.ndarray,
      expected_out: jnp.ndarray,
      *,
      atol: float = 1e-5,
      rtol: float = 1e-5,
  ):
    self.assertEqual(out.dtype, expected_out.dtype)
    np.testing.assert_allclose(
        out.astype(jnp.float32),
        expected_out.astype(jnp.float32),
        atol=atol,
        rtol=rtol,
    )

  def gmm_test(
      self,
      m: int,
      k: int,
      n: int,
      data: hps.DataObject,
      interpret: bool = False,
  ):
    seed = data.draw(seed_strategy())
    num_groups, _ = data.draw(group_strategy(max_stride=1))
    lhs_dtype, rhs_dtype, out_dtype = (
        data.draw(hps.sampled_from([jnp.float32, jnp.bfloat16]))
        for _ in range(3)
    )
    transpose_rhs = data.draw(hps.booleans())

    key = jax.random.key(seed)
    k1, k2 = jax.random.split(key, 2)
    lhs = random_dense((m, k), k1, lhs_dtype, limit=1)
    rhs = random_dense((num_groups, k, n), k2, rhs_dtype, limit=1)
    group_sizes = data.draw(group_sizes_strategy(m=m, num_groups=num_groups))

    out, vjpfun = jax.vjp(
        partial(
            mblx.gmm,
            preferred_element_type=out_dtype,
            transpose_rhs=transpose_rhs,
            interpret=interpret,
        ),
        lhs,
        rhs.swapaxes(1, 2) if transpose_rhs else rhs,
        group_sizes,
    )

    def reference_fn(lhs, rhs, group_sizes, preferred_element_type):
      # Undo the swap applied to the primal rhs so the reference vjp is taken
      # with respect to the same (possibly transposed) rhs layout as the
      # kernel's vjp.
      rhs = rhs.swapaxes(1, 2) if transpose_rhs else rhs
      return reference_gmm(
          lhs, rhs, group_sizes, preferred_element_type=preferred_element_type
      )

    expected_out, reference_vjpfun = jax.vjp(
        partial(reference_fn, preferred_element_type=out_dtype),
        lhs,
        rhs.swapaxes(1, 2) if transpose_rhs else rhs,
        group_sizes,
    )
    self.assertEqual(out.dtype, out_dtype)
    self.assertEqual(expected_out.dtype, out_dtype)

    atol, rtol = tolerances(lhs_dtype, rhs_dtype, out_dtype)
    self.assert_allclose(out, expected_out, atol=atol, rtol=rtol)

    cotangent = random_dense((m, n), k1, out_dtype, limit=1)
    grad_lhs, grad_rhs, *_ = vjpfun(cotangent)
    expected_grad_lhs, expected_grad_rhs, *_ = reference_vjpfun(cotangent)
    self.assert_allclose(grad_lhs, expected_grad_lhs, atol=atol, rtol=rtol)
    self.assert_allclose(grad_rhs, expected_grad_rhs, atol=atol, rtol=rtol)

  @parameterized.parameters(*GROUPED_MATMUL_TESTS)
  @hp.given(hps.data())
  def test_gmm(
      self,
      m: int,
      k: int,
      n: int,
      data: hps.DataObject,
  ):
    self.gmm_test(m, k, n, data)

  # NOTE: Run fewer tests with interpret mode. We just want to sanity check that
  # changes do not break running these kernels with interpret=True.
  @parameterized.parameters(*GROUPED_MATMUL_TESTS[0:1])
  @hp.given(hps.data())
  def test_gmm_interpret(
      self,
      m: int,
      k: int,
      n: int,
      data: hps.DataObject,
  ):
    self.skipTest("interpret mode with dynamic grids is unsupported")
    self.gmm_test(
        m,
        k,
        n,
        data=data,
        interpret=True,
    )

  @parameterized.parameters(*GROUPED_MATMUL_TESTS)
  @hp.given(hps.data())
  def test_gmm_sharded_groups(
      self,
      m: int,
      k: int,
      n: int,
      data: hps.DataObject,
  ):
    seed = data.draw(seed_strategy())
    num_groups, group_stride = data.draw(group_strategy())
    lhs_dtype, rhs_dtype, out_dtype = (
        data.draw(hps.sampled_from([jnp.float32, jnp.bfloat16]))
        for _ in range(3)
    )

    key = jax.random.key(seed)
    k1, k2 = jax.random.split(key, 2)
    lhs = random_dense((m, k), k1, lhs_dtype, limit=1)
    rhs = random_dense((num_groups, k, n), k2, rhs_dtype, limit=1)
    group_sizes = data.draw(group_sizes_strategy(m=m, num_groups=num_groups))

    # Compute the first group shard's contribution; later shards accumulate
    # into this output via `existing_out`.
    out, shard_vjpfun = jax.vjp(
        partial(mblx.gmm, preferred_element_type=out_dtype),
        lhs,
        rhs[0:group_stride],
        group_sizes,
    )
    vjpfuns = [shard_vjpfun]
    # Process the remaining group shards, threading the running output through
    # `existing_out` so each call accumulates into the same result.
    for group_offset in range(group_stride, num_groups, group_stride):
      out, shard_vjpfun = jax.vjp(
          lambda lhs, rhs, group_sizes, out: mblx.gmm(
              lhs,
              rhs,
              group_sizes,
              out_dtype,
              group_offset=jnp.array(group_offset, dtype=jnp.int32),  # pylint: disable=cell-var-from-loop
              existing_out=out,
          ),
          lhs,
          rhs[group_offset : group_offset + group_stride],
          group_sizes,
          out,
      )
      vjpfuns.append(shard_vjpfun)

    expected_out, reference_vjpfun = jax.vjp(
        partial(reference_gmm, preferred_element_type=out_dtype),
        lhs,
        rhs,
        group_sizes,
    )
    self.assertEqual(out.dtype, out_dtype)
    self.assertEqual(expected_out.dtype, out_dtype)
    atol, rtol = tolerances(lhs_dtype, rhs_dtype, out_dtype)
    self.assert_allclose(out, expected_out, atol=atol, rtol=rtol)

    cotangent = random_dense((m, n), k1, out_dtype, limit=1)
    # Sum the per-shard lhs gradients and concatenate the per-shard rhs
    # gradients along the group dimension.
    shard_grad_lhs, shard_grad_rhs, *_ = vjpfuns[0](cotangent)
    grad_lhs = shard_grad_lhs
    grad_rhs = [shard_grad_rhs]
    for i, group_offset in enumerate(
        range(group_stride, num_groups, group_stride)
    ):
      shard_grad_lhs, shard_grad_rhs, *_ = vjpfuns[i + 1](cotangent)
      grad_lhs += shard_grad_lhs
      grad_rhs.append(shard_grad_rhs)
    grad_rhs = jnp.concatenate(grad_rhs, axis=0)
    expected_grad_lhs, expected_grad_rhs, *_ = reference_vjpfun(cotangent)
    self.assert_allclose(grad_lhs, expected_grad_lhs, atol=atol, rtol=rtol)
    self.assert_allclose(grad_rhs, expected_grad_rhs, atol=atol, rtol=rtol)


if __name__ == "__main__":
  absltest.main(testLoader=jtu.JaxTestLoader())
